#pragma once
#include <flashinfer/page.cuh>
#include <flashinfer/math.cuh>
#include <flashinfer/layout.cuh>
#include <flashinfer/utils.cuh>
#include <flashinfer/pos_enc.cuh>
#include <flashinfer/fastdiv.cuh>
#include <flashinfer/attention/scheduler.cuh>
#include <flashinfer/attention/mask.cuh>
#include <flashinfer/attention/variant_helper.cuh>
#include <flashinfer/attention/default_prefill_params.cuh>

// Makes every flashinfer:: name usable unqualified in this file and in
// anything that includes it.
using namespace flashinfer;

// Compile-time kernel configuration. Each double-brace placeholder below is a
// template variable filled in by the code generator, so this file is only
// valid C++ after rendering.
//
// Data types named for Q, KV and O, plus an index type.
// NOTE(review): presumably the element types of the query, key/value and
// output buffers and the integer type of index arrays -- confirm in the kernels.
using DTypeQ = {{ dtype_q }};
using DTypeKV = {{ dtype_kv }};
using DTypeO = {{ dtype_o }};
using IdType = {{ idtype }};
// Head dimensions, kept separately for the QK side and the VO side.
constexpr int HEAD_DIM_QK = {{ head_dim_qk }};
constexpr int HEAD_DIM_VO = {{ head_dim_vo }};
constexpr bool USE_FP16_QK_REDUCTION = {{ use_fp16_qk_reduction }};
// Settings carrying the _P suffix.
// NOTE(review): _P presumably means the prefill stage (cf. PrefillParams
// below) -- confirm. Declared `auto`, so each constant takes the type of
// whatever expression the generator substitutes.
constexpr auto USE_LOGITS_SOFT_CAP_P = {{ use_logits_soft_cap_p }};
constexpr auto POS_ENCODING_MODE_P = {{ pos_encoding_mode_p }};
constexpr auto USE_SLIDING_WINDOW_P = {{ use_sliding_window_p }};

// The same three settings with the _D suffix.
// NOTE(review): _D presumably means the decode stage (cf. DecodeParams
// below) -- confirm.
constexpr auto USE_LOGITS_SOFT_CAP_D = {{ use_logits_soft_cap_d }};
constexpr auto POS_ENCODING_MODE_D = {{ pos_encoding_mode_d }};
constexpr auto USE_SLIDING_WINDOW_D = {{ use_sliding_window_d }};

// Unsuffixed variants are hard-coded rather than templated: no positional
// encoding and no logits soft cap.
constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone;
constexpr bool USE_LOGITS_SOFT_CAP = false;

// Parameter-struct aliases built from the data types above. Note that
// DecodeParams is an alias of the batch-prefill *paged* params type, not of a
// dedicated decode params type.
using PrefillParams = SinglePrefillParams<DTypeQ, DTypeKV, DTypeO>;
using DecodeParams = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>;

// DISPATCH_context: invokes a callable inside two nested DISPATCH_MASK_MODE
// expansions.
//
//   MASK_MODE_P, MASK_MODE_D  Forwarded as the second argument of the outer
//                             and inner DISPATCH_MASK_MODE respectively.
//   DTypeQ ... USE_LOGITS_SOFT_CAP
//                             Accepted so call sites can list the whole
//                             configuration, but not referenced by the
//                             expansion.
//   ...                       A callable expression; it is invoked with no
//                             arguments in the innermost body.
//
// The expansion also reads the identifiers `mask_mode_p` and `mask_mode_d`,
// which are not macro parameters and must therefore already be in scope
// wherever the macro is used.
// NOTE(review): DISPATCH_MASK_MODE is not defined in this file; presumably it
// comes from one of the included flashinfer headers -- confirm.
//
// The expansion ends with its own `;`, so a call site that appends another
// semicolon produces an additional empty statement.
#define DISPATCH_context(MASK_MODE_P, MASK_MODE_D, DTypeQ, DTypeKV, HEAD_DIM_QK, \
                         USE_SLIDING_WINDOW_P, USE_SLIDING_WINDOW_D,             \
                         USE_LOGITS_SOFT_CAP, ...)                               \
  DISPATCH_MASK_MODE(mask_mode_p, MASK_MODE_P, {                                 \
    DISPATCH_MASK_MODE(mask_mode_d, MASK_MODE_D, { __VA_ARGS__(); });            \
  });
